# author zybz 2022.3.16
# Step 2
import xml.etree.ElementTree as ET
import os


class get_labels():
    """Convert XML box annotations into normalised text label files.

    For every image id listed in ``../ImageSets/<set>.txt`` this class

    * appends the image path to ``../<set>.txt`` and
    * reads ``../Annotations/<id>.xml`` and writes ``../labels/<id>.txt``
      with one ``class_id x_center y_center width height`` line per kept
      object, all four box values divided by the image width/height.

    All paths are relative to the current working directory. Constructing
    an instance runs the whole conversion immediately.
    """

    def __init__(self, classes, sets=None):
        """Store the configuration and run the conversion.

        Args:
            classes: Sequence of class names; an object's position in this
                sequence becomes its numeric class id. If ``None``, a
                message is printed and nothing is converted.
            sets: Names of the image-set files to process. Defaults to
                ``["train", "test", "val"]``.
        """
        self.classes = classes
        # A fresh list per instance avoids sharing one mutable default.
        self.sets = ["train", "test", "val"] if sets is None else list(sets)
        if classes is None:
            print("Please input classes")
            return
        self.run()

    def convert(self, size, box):
        """Turn a corner-style box into a normalised centre-style box.

        Args:
            size: ``(image_width, image_height)``.
            box: ``(xmin, xmax, ymin, ymax)``.

        Returns:
            ``(x_center, y_center, width, height)``, with the x values
            divided by the image width and the y values by the image height.
        """
        dw = 1.0 / size[0]
        dh = 1.0 / size[1]
        x_center = (box[0] + box[1]) / 2.0 * dw
        y_center = (box[2] + box[3]) / 2.0 * dh
        box_w = (box[1] - box[0]) * dw
        box_h = (box[3] - box[2]) * dh
        return (x_center, y_center, box_w, box_h)

    def convert_annotation(self, image_set, image_id):
        """Write ``../labels/<image_id>.txt`` from ``../Annotations/<image_id>.xml``.

        Objects whose name is not in ``self.classes`` or whose ``difficult``
        flag is 1 are skipped. A missing ``difficult`` element is treated
        as 0.

        Args:
            image_set: Name of the set being processed (not used here).
            image_id: Basename shared by the XML file and the label file.
        """
        # Context managers guarantee both handles are closed (and the label
        # file flushed) even if parsing raises part-way through.
        with open("../Annotations/%s.xml" % (image_id),
                  encoding='utf-8') as in_file:
            root = ET.parse(in_file).getroot()

        size = root.find("size")
        w = int(size.find("width").text)
        h = int(size.find("height").text)

        with open("../labels/%s.txt" % (image_id), "w",
                  encoding='utf-8') as out_file:
            for obj in root.iter("object"):
                difficult_node = obj.find("difficult")
                if difficult_node is None or not difficult_node.text:
                    difficult = 0
                else:
                    difficult = int(difficult_node.text)
                cls = obj.find("name").text
                if cls not in self.classes or difficult == 1:
                    continue
                cls_id = self.classes.index(cls)
                xmlbox = obj.find("bndbox")
                # Order expected by convert(): xmin, xmax, ymin, ymax.
                b = tuple(
                    float(xmlbox.find(tag).text)
                    for tag in ("xmin", "xmax", "ymin", "ymax")
                )
                bb = self.convert((w, h), b)
                out_file.write(
                    str(cls_id) + " " + " ".join(map(str, bb)) + "\n")

    def run(self):
        """Process every configured image set.

        For each set, reads the whitespace-separated image ids from
        ``../ImageSets/<set>.txt``, writes one image path per id to
        ``../<set>.txt`` and generates the matching label file.
        """
        os.makedirs("../labels/", exist_ok=True)

        # Name of the dataset folder, i.e. the parent of the working
        # directory. os.path handles both "\\" and "/" separators, unlike
        # splitting on a hard-coded backslash.
        datasets_name = os.path.basename(os.path.dirname(os.getcwd()))
        # This path is the one consumed by train.py.
        relative_path = '../datasets/' + datasets_name

        for image_set in self.sets:
            with open("../ImageSets/%s.txt" % (image_set)) as ids_file:
                image_ids = ids_file.read().strip().split()
            with open("../%s.txt" % (image_set), "w") as list_file:
                for image_id in image_ids:
                    list_file.write(relative_path + "/images/%s.jpg\n" %
                                    (image_id))
                    self.convert_annotation(image_set, image_id)
        print("Step 2 make labels successfully")


if __name__ == "__main__":
    # Each class name must be its own list element; the list index is the
    # numeric class id written to the label files. A single 'fire, fall'
    # string would never match any object name in the annotations.
    demo = get_labels(classes=['fire', 'fall'])